import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score


def multiclass_acc(preds, truths):
  """Compute multiclass accuracy against the ground truth.

  Continuous predictions and labels are rounded to the nearest integer
  class before comparison.

  :param preds: Float array of predictions, shape (N,)
  :param truths: Float/int array of ground-truth classes, shape (N,)
  :return: Classification accuracy in [0, 1]
  """
  return np.sum(np.round(preds) == np.round(truths)) / float(len(truths))
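
# Minimal usage sketch (toy values, not from any dataset): continuous
# predictions are rounded to the nearest integer class, so 1.4 counts as
# class 1 and 2.6 as class 3.
#
#   >>> multiclass_acc(np.array([1.4, -0.2, 2.6]), np.array([1., 0., 3.]))
#   1.0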


def weighted_accuracy(test_preds_emo, test_truth_emo):
  """Compute weighted (balanced) binary accuracy.

  Scores are binarized at 0; positive-class hits are reweighted by n / p so
  that (tp * (n / p) + tn) / (2 * n) equals the mean of the per-class
  recalls, (tp / p + tn / n) / 2.
  """
  true_label = (test_truth_emo > 0)
  predicted_label = (test_preds_emo > 0)
  tp = float(np.sum((true_label == 1) & (predicted_label == 1)))
  tn = float(np.sum((true_label == 0) & (predicted_label == 0)))
  p = float(np.sum(true_label == 1))
  n = float(np.sum(true_label == 0))

  return (tp * (n / p) + tn) / (2 * n)
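
# Hedged sketch with toy scores thresholded at 0: p = 2 positives,
# n = 2 negatives, tp = 1, tn = 2, giving (1 * (2 / 2) + 2) / (2 * 2) = 0.75,
# i.e. the mean of the per-class recalls (balanced accuracy).
#
#   >>> weighted_accuracy(np.array([0.5, -0.1, -0.4, -0.2]),
#   ...                   np.array([1.0, 2.0, -1.0, -3.0]))
#   0.75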


def eval_mosei_senti(results, truths, exclude_zero=False):
  """Evaluate continuous sentiment predictions (CMU-MOSEI style).

  :param results: Predicted sentiment scores, any shape reducible to (N,)
  :param truths: Ground-truth sentiment scores, same number of elements
  :param exclude_zero: If True, drop neutral (zero-label) examples from the
    binary accuracy and F1 computations
  :return: Dict of metrics (mae, corr, a7, a5, f1, a2)
  """
  test_preds = results.view(-1).cpu().detach().numpy()
  test_truth = truths.view(-1).cpu().detach().numpy()

  # Indices used for the binary metrics; excludes zero labels when requested.
  non_zeros = np.array(
      [i for i, e in enumerate(test_truth) if e != 0 or (not exclude_zero)])

  # Clip into the 7-class ([-3, 3]) and 5-class ([-2, 2]) ranges.
  test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.)
  test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.)
  test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.)
  test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.)

  # Average L1 distance between predictions and truths.
  mae = np.mean(np.absolute(test_preds - test_truth))
  corr = np.corrcoef(test_preds, test_truth)[0][1]
  mult_a7 = multiclass_acc(test_preds_a7, test_truth_a7)
  mult_a5 = multiclass_acc(test_preds_a5, test_truth_a5)
  binary_truth = (test_truth[non_zeros] > 0)
  binary_preds = (test_preds[non_zeros] > 0)
  # sklearn metrics expect (y_true, y_pred) argument order.
  f_score = f1_score(binary_truth, binary_preds, average='weighted')
  acc_2 = accuracy_score(binary_truth, binary_preds)

  # print("MAE: ", mae)
  # print("Correlation Coefficient: ", corr)
  # print("mult_acc_7: ", mult_a7)
  # print("mult_acc_5: ", mult_a5)
  # print("F1 score: ", f_score)
  # print("Accuracy: ", acc_2)

  # print("-" * 50)

  metrics = {
      'mae': mae,
      'corr': corr,
      'a7': mult_a7,
      'a5': mult_a5,
      'f1': f_score,
      'a2': acc_2
  }
  return metrics
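
# Hedged usage sketch (random tensors, only to show the expected shapes and
# the keys of the returned dict; not real model output):
#
#   >>> preds = torch.randn(128)   # continuous sentiment scores
#   >>> labels = torch.randn(128)  # continuous labels, roughly in [-3, 3]
#   >>> sorted(eval_mosei_senti(preds, labels).keys())
#   ['a2', 'a5', 'a7', 'corr', 'f1', 'mae']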


def eval_emotion(results, truths, dataset):
  """Evaluate per-emotion binary classification (IEMOCAP / CMU-MOSEI emotions).

  :param results: Logits of shape (N, num_emotions, 2)
  :param truths: Integer labels of shape (N, num_emotions), values in {0, 1}
  :param dataset: Either 'iemocap' or 'mosei_full_emo'
  :return: Dict of per-emotion and averaged accuracy/F1 scores
  """
  if dataset == 'iemocap':
    emos = ["Neutral", "Happy", "Sad", "Angry"]
  elif dataset == 'mosei_full_emo':
    emos = ["Happy", "Sad", "Angry", "Surprise", "Disgust", "Fear"]
  else:
    raise RuntimeError(f'Unconfigured emotion dataset {dataset}')

  dim = len(emos)
  test_preds = results.view(-1, dim, 2).cpu().detach().numpy()
  test_truth = truths.view(-1, dim).cpu().detach().numpy()

  # Accumulate per-emotion metrics plus averaged overall scores.
  metrics = {'acc_Overall': 0., 'f1_Overall': 0.}

  for emo_ind in range(dim):
    # Argmax over the two logits gives the binary prediction for this emotion.
    test_preds_i = np.argmax(test_preds[:, emo_ind], axis=1)
    test_truth_i = test_truth[:, emo_ind]
    f1 = f1_score(test_truth_i, test_preds_i, average='weighted')
    acc = accuracy_score(test_truth_i, test_preds_i)

    metrics['acc_' + emos[emo_ind]] = acc
    metrics['f1_' + emos[emo_ind]] = f1
    metrics['acc_Overall'] += acc
    metrics['f1_Overall'] += f1
  # Average the overall scores across emotions.
  metrics['acc_Overall'] /= dim
  metrics['f1_Overall'] /= dim

  return metrics
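
# Hedged usage sketch (random logits and labels, only to illustrate the
# expected shapes; the four class names follow the 'iemocap' branch above):
#
#   >>> logits = torch.randn(32, 4, 2)          # (batch, emotion, 2 logits)
#   >>> labels = torch.randint(0, 2, (32, 4))   # binary label per emotion
#   >>> out = eval_emotion(logits, labels, 'iemocap')
#   >>> 'acc_Happy' in out and 'f1_Overall' in out
#   True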